from BasicLibraries import *
import utils_nc as ut

'''
This module defines the class that puts the dataset into the right shape for the regressions.
'''


#%%
class BehavioralEffect():
    """Shape a panel data set (rows keyed by ``gvkey`` and a year column) for regressions.

    The methods expect the row label of the input frame to be available as
    ``mrg`` (they call ``reset_index()`` and then read or set a ``mrg``
    column), and they rely on the columns ``gvkey`` and ``rfyear``.
    """

    def __init__(self):
        pass

    def get_vars_and_dummies(self, X, region_var, emissionVar, initiatives, envSDGs, YearType, analysis_type = 'Paris'):
        """Filter the sample, build the 'Environmental SDGs' score and the dummy sets.

        Parameters
        ----------
        X : pandas.DataFrame
            Input data. Must contain ``sale_usd``, ``Total_invested_capital_BS1``,
            ``gvkey``, ``rfyear``, ``GICS_level_1``, ``region_var``, ``YearType``
            and the initiative columns; ``mrg`` must be the index name (or a column).
        region_var : str
            Column used to build the region dummies.
        emissionVar :
            Not used by this method; kept for interface compatibility.
        initiatives : str or collection
            ``'all'`` to take the initiatives from ``ut.utils().return_Initiatives()``,
            otherwise the value is passed to ``get_initiativeKeys`` unchanged.
        envSDGs :
            Forwarded to ``ut.utils().get_initiativeKeys``.
        YearType : str
            Column used for the year dummies and for the final sort.
        analysis_type : str, optional
            Not used by this method; kept for interface compatibility.

        Returns
        -------
        tuple
            ``(dt, g_dummies, s_dummies, y_dummies, tmK)``: the prepared frame
            indexed by ``mrg`` and sorted by ``gvkey``/``YearType``, the region,
            sector and year dummies (first level dropped), and the initiatives used.
        """
        # Keep only rows with strictly positive sales and invested capital.
        keep_rows = (X.sale_usd > 0) & (X.Total_invested_capital_BS1 > 0)
        dt = X[keep_rows].copy()

        # The two original branches differed only in where the initiatives come
        # from. The isinstance guard keeps the comparison unambiguous when
        # `initiatives` is an array-like.
        if isinstance(initiatives, str) and initiatives == 'all':
            tmK = ut.utils().return_Initiatives()
        else:
            tmK = initiatives
        keysToRun = ut.utils().get_initiativeKeys(tmK, envSDGs)

        dt['Environmental SDGs'] = dt[keysToRun].sum(axis = 1)
        dt = dt.sort_values(by = ['gvkey', 'rfyear']).reset_index()
        # Two-row rolling mean within each gvkey (the first row of every group
        # becomes NaN). The result is assigned by index alignment rather than
        # by position, so it does not depend on the row order or the column
        # layout returned by groupby.
        rolling_mean = dt.groupby('gvkey')['Environmental SDGs'].rolling(2).mean()
        dt['Environmental SDGs'] = rolling_mean.reset_index(level = 0, drop = True)
        dt = dt.set_index('mrg')

        ### Get dummies
        g_dummies = pd.get_dummies(dt[region_var], drop_first = True).astype(int)
        s_dummies = pd.get_dummies(dt['GICS_level_1'], drop_first = True).astype(int)
        # NOTE(review): unlike the region and sector dummies, the year dummies
        # are not cast to int; left as in the original -- confirm this is intended.
        y_dummies = pd.get_dummies(dt[YearType], drop_first = True)

        ### sort data
        dt = dt.sort_values(by = ['gvkey', YearType])
        return dt, g_dummies, s_dummies, y_dummies, tmK

    @staticmethod
    def _clean_panel(L, YearType, vars_, dummy_frames):
        """Join the variables with the dummies and drop rows with NaN or +/-inf."""
        dx = L.sort_values(by = ['gvkey', YearType])
        dx = pd.concat([dx[['gvkey', YearType] + vars_], *dummy_frames], axis = 1).dropna()
        # reset_index() turns the row label into a regular column ('mrg').
        dx = dx.reset_index()
        dx = dx.replace([-np.inf, np.inf], np.nan).dropna().reset_index(drop = True)
        return dx

    @staticmethod
    def _prune_empty_dummies(dx, g_dummies, s_dummies, y_dummies):
        """Drop dummy columns that are all zero in ``dx``; keep column order."""
        dummy_frames = (g_dummies, s_dummies, y_dummies)
        dummy_labels = [label for frame in dummy_frames for label in frame.columns]
        is_dummy = dx.columns.isin(dummy_labels)

        # Only dummy columns are candidates for removal. Key and regression
        # variables are always kept, even when their values cancel to zero.
        drop = np.zeros(len(dx.columns), dtype = bool)
        drop[is_dummy] = ~dx.loc[:, is_dummy].ne(0).any(axis = 0).to_numpy(dtype = bool)
        dx = dx.loc[:, ~drop]

        # Boolean masks preserve the original column order of every dummy frame,
        # so the resulting design matrix is identical from run to run.
        g_kept, s_kept, y_kept = [
            frame.loc[:, frame.columns.isin(dx.columns)] for frame in dummy_frames
        ]
        return dx, g_kept, s_kept, y_kept

    def make_vars_continuous(self, L, YearType, vars_, g_dummies, s_dummies, y_dummies):
        """Build the regression matrix without any minimum-rows-per-gvkey filter.

        Parameters
        ----------
        L : pandas.DataFrame
            Prepared data whose index is named ``mrg`` and aligns with the dummies.
        YearType : str
            Year column used for sorting.
        vars_ : list
            Names of the regression variables to keep.
        g_dummies, s_dummies, y_dummies : pandas.DataFrame
            Region, sector and year dummies indexed like ``L``.

        Returns
        -------
        tuple
            ``(X, g_dummies, s_dummies, y_dummies, dx)``. ``dx`` is the cleaned
            frame; ``X`` holds ``mrg``, ``gvkey``, ``vars_`` and the retained
            dummy columns. The returned dummy frames keep all their rows but
            only the columns that are not all zero in ``dx``.
        """
        vars_ = list(vars_)
        dx = self._clean_panel(L, YearType, vars_, (g_dummies, s_dummies, y_dummies))
        dx, g_dummies, s_dummies, y_dummies = self._prune_empty_dummies(
            dx, g_dummies, s_dummies, y_dummies)

        #============== Get the dependent and independent variables
        X = dx[['mrg', 'gvkey'] + vars_ + list(g_dummies.columns)
               + list(s_dummies.columns) + list(y_dummies.columns)]

        return X, g_dummies, s_dummies, y_dummies, dx

    def make_vars(self, L, YearType, vars_, g_dummies, s_dummies, y_dummies):
        """Build the regression matrix, keeping only gvkeys with more than one row.

        Same inputs as ``make_vars_continuous``. After cleaning, every
        ``gvkey`` that is left with a single row is removed before the all-zero
        dummy columns are dropped.

        Returns
        -------
        tuple
            ``(X, g_dummies, s_dummies, y_dummies, dx)``. ``X`` holds ``mrg``,
            ``rfyear``, ``gvkey``, ``vars_`` and the retained dummy columns, so
            ``rfyear`` must be present in the cleaned frame (i.e. it is
            ``YearType`` or one of ``vars_``).
        """
        vars_ = list(vars_)
        dx = self._clean_panel(L, YearType, vars_, (g_dummies, s_dummies, y_dummies))

        # Keep only gvkeys that occur more than once after cleaning.
        dx = dx[dx['gvkey'].duplicated(keep = False)]

        dx, g_dummies, s_dummies, y_dummies = self._prune_empty_dummies(
            dx, g_dummies, s_dummies, y_dummies)

        #============== Get the dependent and independent variables
        X = dx[['mrg', 'rfyear', 'gvkey'] + vars_ + list(g_dummies.columns)
               + list(s_dummies.columns) + list(y_dummies.columns)]

        return X, g_dummies, s_dummies, y_dummies, dx



    
        
